# Server_API/app/api/schemas/chat_request_models.py
# Description: This code provides schema models for the /chat API endpoints
#
# Imports
import os
from typing import Optional, Dict, Any, Literal, Union, List

from dotenv import load_dotenv

#
# 3rd-party imports
from pydantic import BaseModel, ConfigDict, Field, HttpUrl, field_validator, model_validator

#
# Local Imports
from tldw_Server_API.app.core.config import load_comprehensive_config

#
#######################################################################################################################
#
# Functions:

# --- Pydantic Models for OpenAI Chat Completion Request ---
# Based on https://platform.openai.com/docs/api-reference/chat/create

# Hard-coded fallback provider. It may be reassigned further down in this module
# from the loaded config, the DEFAULT_LLM_PROVIDER env var, or TEST_MODE.
DEFAULT_LLM_PROVIDER = "openai"
# NOTE(review): module-level ConfigDict that nothing in the visible part of this
# file references (ChatCompletionRequest declares its own model_config).
# Presumably kept for external importers - confirm before removing.
model_config = ConfigDict(extra="allow", from_attributes=True)

# Config Loading
# Load the standard .env first, then also try an uppercase .ENV for environments that use it.
# Both calls run at import time.
try:
    load_dotenv()
    # Second pass over an uppercase filename (non-standard but requested);
    # override=False is passed so this pass does not override already-set variables.
    load_dotenv(dotenv_path=".ENV", override=False)
except Exception:
    # Deliberate best-effort: fall back silently if dotenv loading fails,
    # since the environment may already be populated by other means.
    pass

# Use load_and_log_configs which returns a proper dict
from tldw_Server_API.app.core.config import load_and_log_configs

# Snapshot of the configuration taken once at import time; `or {}` substitutes an
# empty dict when the loader returns a falsy value.
_config = load_and_log_configs() or {}


def _config_default_llm_provider(config_data: Optional[Dict[str, Any]]) -> Optional[str]:
    """Extract the default LLM provider name from a loaded config mapping.

    The ``llm_api_settings`` section is consulted first, then ``API``. Within a
    section only a string ``default_api`` value that is non-empty after
    stripping whitespace counts as a hit.

    Args:
        config_data: Parsed configuration; anything that is not a ``dict`` is
            treated as "no configuration".

    Returns:
        The stripped provider name from the first section that supplies one,
        or ``None`` when no section does.
    """
    if isinstance(config_data, dict):
        section_names = ("llm_api_settings", "API")
        for settings in (config_data.get(name) for name in section_names):
            raw_name = settings.get("default_api") if isinstance(settings, dict) else None
            provider_name = raw_name.strip() if isinstance(raw_name, str) else ""
            if provider_name:
                return provider_name
    return None


# Resolve the effective default provider at import time.
# Precedence (first truthy wins): config file value, DEFAULT_LLM_PROVIDER env var,
# "local-llm" when TEST_MODE is enabled; otherwise the existing default is kept.
_cfg_default_provider = _config_default_llm_provider(_config)
_env_default_provider = os.getenv("DEFAULT_LLM_PROVIDER")
_test_mode_enabled = os.getenv("TEST_MODE", "").lower() in ("1", "true", "yes")

DEFAULT_LLM_PROVIDER = (
    _cfg_default_provider
    or _env_default_provider
    or ("local-llm" if _test_mode_enabled else DEFAULT_LLM_PROVIDER)
)


def _get_setting(env_var, section, key, default=""):
    """Resolve one setting from the environment or the import-time config snapshot.

    Lookup order (first hit wins):
      1. Environment variable ``env_var`` - any set value, including "".
      2. Only when ``section == "api_keys"``: a truthy ``api_key`` entry inside
         the ``"<key>_api"`` config section.
      3. ``_config[section][key]`` when the section is a non-empty dict and the
         value is not ``None``.
      4. ``default``.
    """
    from_env = os.getenv(env_var)
    if from_env is not None:
        return from_env

    if section == "api_keys":
        # Provider-specific sections are named like 'openai_api'.
        provider_block = _config.get(f"{key}_api")
        stored_key = provider_block.get("api_key") if isinstance(provider_block, dict) else None
        if stored_key:
            return stored_key

    # Generic fallback: look the key up directly in the requested section.
    section_block = _config.get(section)
    if section_block and isinstance(section_block, dict):
        found = section_block.get(key)
        if found is not None:
            return found

    return default


# Runtime list of provider identifiers. get_api_keys() iterates this list to build
# its result, and the SUPPORTED_API_ENDPOINTS Literal further down repeats the same
# 20 names for static typing - keep the two in sync when adding a provider.
ALL_SUPPORTED_PROVIDER_NAMES_LIST: List[str] = [
    "bedrock",
    "anthropic",
    "cohere",
    "deepseek",
    "google",
    "groq",
    "qwen",
    "huggingface",
    "mistral",
    "openai",
    "openrouter",
    "llama.cpp",
    "kobold",
    "ollama",
    "ooba",
    "tabbyapi",
    "vllm",
    "local-llm",
    "custom-openai-api",
    "custom-openai-api-2",
]


def get_api_keys() -> Dict[str, Optional[str]]:
    """
    Get API keys dynamically to support runtime changes.

    The config is reloaded via ``load_and_log_configs()`` and the environment is
    re-read on every call, so test/runtime environment changes are reflected.

    Lookup order per provider (first hit wins):
      1. Env var ``<NAME>_API_KEY`` where NAME is the provider upper-cased with
         '.' replaced by '_' (the original spelling, e.g. ``LLAMA_CPP_API_KEY``).
      2. For hyphenated providers, the same name with '-' also replaced by '_'
         (e.g. ``LOCAL_LLM_API_KEY``). The original spelling ``LOCAL-LLM_API_KEY``
         is not a valid shell identifier, so this alias makes the key settable
         via a normal ``export``; the original name still takes precedence.
      3. A truthy ``api_key`` entry in the ``"<provider>_api"`` config section.
      4. A non-None ``config["api_keys"][<provider>]`` entry.
      5. The empty string.

    Returns:
        A new dict with one entry per name in ALL_SUPPORTED_PROVIDER_NAMES_LIST.
    """
    # Reload config to get latest values
    current_config = load_and_log_configs() or {}

    def _env_names(provider):
        # Legacy spelling first so existing deployments keep their behavior.
        legacy = f"{provider.upper().replace('.', '_')}_API_KEY"
        portable = legacy.replace("-", "_")
        return (legacy,) if portable == legacy else (legacy, portable)

    def _get_dynamic_setting(provider, default=""):
        # Environment wins; any set value (including "") is returned as-is.
        for env_name in _env_names(provider):
            env_value = os.getenv(env_name)
            if env_value is not None:
                return env_value

        # Provider-specific API config section like 'openai_api'.
        provider_section = current_config.get(f"{provider}_api")
        if isinstance(provider_section, dict):
            api_key = provider_section.get("api_key")
            if api_key:
                return api_key

        # Fallback to the shared 'api_keys' section.
        shared_section = current_config.get("api_keys")
        if isinstance(shared_section, dict):
            config_value = shared_section.get(provider)
            if config_value is not None:
                return config_value

        return default

    return {name: _get_dynamic_setting(name) for name in ALL_SUPPORTED_PROVIDER_NAMES_LIST}


# Keep API_KEYS for backward compatibility. NOTE: this is an eager snapshot - the
# call below runs once at import time and the dict is not recomputed on access.
# This map is for test overrides and transitional compatibility; new code should
# use get_api_keys() / resolve_provider_api_key.
API_KEYS = get_api_keys()

# For type hinting - define explicitly.
# The entries repeat ALL_SUPPORTED_PROVIDER_NAMES_LIST verbatim; keep both in sync.
SUPPORTED_API_ENDPOINTS = Literal[
    "bedrock",
    "anthropic",
    "cohere",
    "deepseek",
    "google",
    "groq",
    "qwen",
    "huggingface",
    "mistral",
    "openai",
    "openrouter",
    "llama.cpp",
    "kobold",
    "ollama",
    "ooba",
    "tabbyapi",
    "vllm",
    "local-llm",
    "custom-openai-api",
    "custom-openai-api-2",
]


# --- Tool Definitions ---
class FunctionDefinition(BaseModel):
    """Describes a function available to the model."""

    # Required identifier of the callable.
    name: str = Field(
        default=...,
        description="The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64.",
    )
    # Optional free-text hint for the model.
    description: Optional[str] = Field(
        default=None,
        description="A description of what the function does, used by the model to choose when and how to call the function.",
    )
    # default_factory=dict: an omitted value becomes a fresh empty dict per instance.
    parameters: Optional[Dict[str, Any]] = Field(
        default_factory=dict,
        description="The parameters the functions accepts, described as a JSON Schema object. See the guide[1] for examples, and the JSON Schema reference[2] for documentation about the format. Omitting parameters defines a function with an empty parameter list. [1] https://platform.openai.com/docs/guides/function-calling [2] https://json-schema.org/understanding-json-schema/",
    )


class ToolDefinition(BaseModel):
    """Definition of a tool the model can use."""

    # Required tag; "function" is the only accepted value.
    type: Literal["function"] = Field(
        default=...,
        description="The type of the tool. Currently, only `function` is supported.",
    )
    # Required nested description of the callable (no default).
    function: FunctionDefinition


class ToolChoiceFunction(BaseModel):
    """Specifies a specific function to be called."""

    # Required; identifies the function the caller wants invoked.
    name: str = Field(
        default=...,
        description="The name of the function to call.",
    )


class ToolChoiceOption(BaseModel):
    """Specifies a tool the model should use. Use to force the model to call a specific function."""

    # Required tag; "function" is the only accepted value.
    type: Literal["function"] = Field(
        default=...,
        description="The type of the tool. Currently, only `function` is supported.",
    )
    # Required reference (by name) to the function being forced.
    function: ToolChoiceFunction


# --- Message Definitions ---
class ChatCompletionRequestMessageContentPartText(BaseModel):
    # Text member of the ChatCompletionRequestMessageContentPart union.
    type: Literal["text"]  # required tag; must be exactly "text"
    text: str  # required text payload


class ChatCompletionRequestMessageContentPartImageURL(BaseModel):
    # Image reference for an `image_url` content part. Although the annotation
    # admits HttpUrl, the validator below only lets `data:image...` URIs through.
    url: Union[HttpUrl, str] = Field(
        ...,
        description=(
            "Base64-encoded image data as a data URI (e.g., 'data:image/png;base64,...'). "
            "External HTTP/HTTPS URLs are not supported for chat images."
        ),
    )
    detail: Optional[Literal["auto", "low", "high"]] = Field(
        "auto", description="Specifies the detail level of the image."
    )

    @field_validator("url")
    @classmethod
    def check_url_or_data(cls, v):
        """Normalize the value to ``str`` and require an image data URI.

        Args:
            v: The value produced by validating against ``Union[HttpUrl, str]``.

        Returns:
            The URL as a plain string.

        Raises:
            ValueError: If the string does not start with ``data:image``.
        """
        # Do not test `isinstance(v, HttpUrl)`: on pydantic releases where HttpUrl
        # is an Annotated alias rather than a class, that check raises TypeError
        # for every input. Anything that is not already a str (i.e. a parsed URL
        # object) is simply stringified.
        if not isinstance(v, str):
            v = str(v)
        if not v.startswith("data:image"):
            raise ValueError("url must be a data URI for base64 encoded images")
        return v


class ChatCompletionRequestMessageContentPartImage(BaseModel):
    # Image member of the ChatCompletionRequestMessageContentPart union.
    type: Literal["image_url"]  # required tag; must be exactly "image_url"
    image_url: ChatCompletionRequestMessageContentPartImageURL  # required nested image reference


# Content Part Union
# One element of a list-form user message `content`: either a text part or an image part.
ChatCompletionRequestMessageContentPart = Union[
    ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage
]


# Base Message Model
class BaseMessage(BaseModel):
    # Fields shared by every message type; each subclass narrows `role` to one literal.
    role: Literal["system", "user", "assistant", "tool"]
    name: Optional[str] = Field(
        default=None,
        description="An optional name for the participant. Provides the model information to differentiate between participants of the same role.",
    )


# Specific Message Types
class ChatCompletionSystemMessageParam(BaseMessage):
    # System message: `role` narrowed to "system"; `name` is inherited from BaseMessage.
    role: Literal["system"]
    content: str  # required; plain string only


class ChatCompletionUserMessageParam(BaseMessage):
    # User message: `role` narrowed to "user"; `name` is inherited from BaseMessage.
    role: Literal["user"]
    # Required; either a plain string or a list of text / image content parts.
    content: Union[str, List[ChatCompletionRequestMessageContentPart]]


class FunctionCall(BaseModel):
    """
    Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model.
    """

    # Both fields are required. `arguments` is kept as a raw string; this model
    # does not parse it.
    arguments: str = Field(
        default=...,
        description="The arguments to call the function with, as generated by the model in JSON format.",
    )
    name: str = Field(
        default=...,
        description="The name of the function to call.",
    )


class ChatCompletionMessageToolCallParam(BaseModel):
    # One tool call previously emitted by the assistant and echoed back in the history.
    id: str = Field(..., description="The ID of the tool call.")
    type: Literal["function"] = Field(..., description="The type of the tool. Currently, only `function` is supported.")
    # A tool call carries the function `name` plus its `arguments` string.
    # FunctionDefinition alone has no `arguments` field, so that value was not
    # retained on the parsed model. FunctionCall is listed first so payloads that
    # include `arguments` keep them; FunctionDefinition stays as a fallback so the
    # previously accepted shape (name only / description / parameters) still validates.
    function: Union[FunctionCall, FunctionDefinition] = Field(
        ..., description="The function that the model called."
    )


class ChatCompletionAssistantMessageParam(BaseMessage):
    # Assistant message: must carry content, tool_calls, or the deprecated function_call.
    role: Literal["assistant"]
    content: Optional[str] = Field(
        None,
        description="The contents of the assistant message. Required unless tool_calls or function_call is specified.",
    )
    tool_calls: Optional[List[ChatCompletionMessageToolCallParam]] = Field(
        None, description="The tool calls generated by the model, such as function calls."
    )
    # Deprecated function_call - include for compatibility if needed, but prefer tool_calls
    function_call: Optional[FunctionCall] = Field(
        None, deprecated=True, description="Deprecated and replaced by `tool_calls`."
    )

    @model_validator(mode="before")
    @classmethod
    def check_content_or_tool_call(cls, values):
        """Reject assistant messages that carry neither content nor a call.

        Runs before field validation, so ``values`` is the raw input.

        Raises:
            ValueError: If ``content`` is None and both ``tool_calls`` and
                ``function_call`` are missing or empty.
        """
        # This model is one member of the message Union, so it can be handed
        # arbitrary raw input (e.g. a bare string in `messages`). Calling `.get`
        # on that would raise AttributeError instead of a validation error;
        # pass non-dict input through and let normal validation reject it.
        if not isinstance(values, dict):
            return values
        content = values.get("content")
        tool_calls = values.get("tool_calls")
        function_call = values.get("function_call")  # Include deprecated field check if necessary
        if content is None and not tool_calls and not function_call:
            raise ValueError("Assistant message must have content or tool_calls (or deprecated function_call)")
        return values


class ChatCompletionToolMessageParam(BaseMessage):
    # Tool message: `role` narrowed to "tool"; both fields below are required.
    role: Literal["tool"]
    content: str = Field(
        default=...,
        description="The contents of the tool message.",
    )
    tool_call_id: str = Field(
        default=...,
        description="Tool call that this message is responding to.",
    )


# Message Union
# Any one of the four role-specific message models; each member narrows `role`
# to a single literal value.
ChatCompletionMessageParam = Union[
    ChatCompletionSystemMessageParam,
    ChatCompletionUserMessageParam,
    ChatCompletionAssistantMessageParam,
    ChatCompletionToolMessageParam,
]


# --- Response Format ---
class ResponseFormat(BaseModel):
    # Requested output format; defaults to "text" when omitted.
    type: Literal["text", "json_object"] = Field(
        default="text",
        description="Must be one of `text` or `json_object`.",
    )


# --- Main Request Model ---
class ChatCompletionRequest(BaseModel):
    """
    Model representing the request body for the /v1/chat/completions endpoint.
    Acts as a proxy request, routing to different LLM providers.
    """

    # --- Routing Parameter ---
    api_provider: Optional[SUPPORTED_API_ENDPOINTS] = Field(
        default=None,  # Default is handled server-side
        description=f"[Extension] The target LLM provider (e.g., 'openai', 'anthropic'). If omitted, defaults to the server's configured default ('{DEFAULT_LLM_PROVIDER}').",
    )

    # --- Standard OpenAI-like Parameters ---
    model: Optional[str] = Field(
        None, description="ID of the model to use. Specific model compatibility depends on the selected `api_provider`."
    )
    messages: List[ChatCompletionMessageParam] = Field(
        ..., description="A list of messages comprising the conversation so far.", min_length=1
    )
    frequency_penalty: Optional[float] = Field(
        None, ge=-2.0, le=2.0, description="Frequency penalty parameter (provider support varies)."
    )
    logit_bias: Optional[Dict[str, float]] = Field(None, description="Logit bias parameter (provider support varies).")
    logprobs: Optional[bool] = Field(
        False, description="Whether to return log probabilities (provider support varies)."
    )
    top_logprobs: Optional[int] = Field(
        None,
        ge=0,
        le=20,
        description="Number of top log probabilities to return (provider support varies). `logprobs` must be true.",
    )
    max_tokens: Optional[int] = Field(
        None, ge=1, description="Maximum number of tokens to generate (provider support varies)."
    )
    n: Optional[int] = Field(
        1, ge=1, le=128, description="Number of completions to generate (provider support varies)."
    )
    presence_penalty: Optional[float] = Field(
        None, ge=-2.0, le=2.0, description="Presence penalty parameter (provider support varies)."
    )
    response_format: Optional[ResponseFormat] = Field(
        None, description="Response format specification (e.g., JSON mode, provider support varies)."
    )
    seed: Optional[int] = Field(None, description="Seed for deterministic sampling (provider support varies).")
    stop: Optional[Union[str, List[str]]] = Field(None, description="Stop sequences (provider support varies).")
    stream: Optional[bool] = Field(False, description="Whether to stream the response.")
    temperature: Optional[float] = Field(
        None, ge=0.0, le=2.0, description="Sampling temperature (provider support varies)."
    )
    top_p: Optional[float] = Field(
        None, ge=0.0, le=1.0, description="Top-p (nucleus) sampling parameter (provider support varies)."
    )
    tools: Optional[List[ToolDefinition]] = Field(
        None, max_length=128, description="Tools the model may call (provider support varies)."
    )
    tool_choice: Optional[Union[Literal["none", "auto", "required"], ToolChoiceOption]] = Field(
        "auto", description="Controls tool usage (provider support varies)."
    )
    user: Optional[str] = Field(None, description="End-user identifier for monitoring.")

    # --- Slash Commands Injection Override ---
    slash_command_injection_mode: Optional[Literal["system", "preface", "replace"]] = Field(
        None,
        description="[Extension] Override the server's slash command injection behavior for this request."
        " Options: 'system' (default server behavior), 'preface', or 'replace'.",
    )

    # --- Conversation history controls ---
    history_message_limit: Optional[int] = Field(
        None, ge=1, le=500, description="Optional override for the number of previous messages to load into context."
    )
    history_message_order: Optional[Literal["asc", "desc"]] = Field(
        None, description="Optional override for history ordering: 'asc' for oldest first, 'desc' for newest first."
    )

    # --- Bedrock Guardrails Extensions ---
    extra_headers: Optional[Dict[str, str]] = Field(
        None,
        description="Provider-specific additional headers to include via the request body (e.g., Bedrock guardrails)."
        " For Bedrock, include keys: 'X-Amzn-Bedrock-GuardrailIdentifier',"
        " 'X-Amzn-Bedrock-GuardrailVersion', and optional 'X-Amzn-Bedrock-Trace'.",
    )
    extra_body: Optional[Dict[str, Any]] = Field(
        None,
        description="Provider-specific extra body content. For Bedrock guardrails, include"
        " 'amazon-bedrock-guardrailConfig': { 'tagSuffix': '...'} if needed.",
    )

    # --- Extended Parameters for chat_api_call ---
    minp: Optional[float] = Field(None, description="[Extension] Minimum probability threshold (provider specific).")
    topk: Optional[int] = Field(None, description="[Extension] Top-K sampling parameter (provider specific).")
    # topp: Optional[float] = Field(None, description="[Extension] Explicit Top-P if needed separately from top_p.") # Uncomment if needed

    # --- Prompt templating ---
    prompt_template_name: Optional[str] = Field(
        None,
        description="Name of the prompt template to apply. Must contain only alphanumeric characters, underscores, and hyphens.",
    )

    # --- Optional Character Chat Parameters ---
    character_id: Optional[str] = Field(None, description="Optional ID of the character to use for context.")
    conversation_id: Optional[str] = Field(None, description="Optional ID of the conversation to use for context.")
    save_to_db: Optional[bool] = Field(
        None,
        description=(
            "[Extension] If true, persist conversation and messages to the database. "
            "If omitted, the server uses its configured default (see Chat-Module.chat_save_default/default_save_to_db)."
        ),
    )
    # extra="allow": request keys not declared above are accepted rather than rejected.
    model_config = ConfigDict(
        extra="allow",
        json_schema_extra={
            "example": {
                "model": "gpt-4o-mini",
                "messages": [{"role": "user", "content": "Summarize the key points of this project."}],
                "stream": False,
                "api_provider": "openai",
            }
        },
    )

    @field_validator("prompt_template_name")
    @classmethod
    def validate_template_name(cls, v: Optional[str]) -> Optional[str]:
        """Validate prompt template name to prevent path traversal attacks.

        Args:
            v: The supplied template name, or None when omitted.

        Returns:
            The name unchanged when it is None or valid.

        Raises:
            ValueError: If the name contains anything other than ASCII letters,
                digits, underscores and hyphens.
        """
        if v is None:
            return v

        import re

        # Only allow alphanumeric, underscore, and hyphen.
        # fullmatch is used instead of match(r"^...$") because `$` also matches
        # just before a trailing newline, which would let "name\n" through.
        if not re.fullmatch(r"[a-zA-Z0-9_-]+", v):
            raise ValueError(
                "Invalid template name format. Only alphanumeric characters, underscores, and hyphens are allowed."
            )

        # Additional security check for path traversal patterns (defense in depth;
        # the character whitelist above already excludes these).
        if "/" in v or "\\" in v or ".." in v:
            raise ValueError("Invalid template name. Path traversal patterns are not allowed.")

        return v

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, values):
        """Require ``logprobs`` to be truthy whenever ``top_logprobs`` is given.

        Runs before field validation, so ``values`` is the raw request body.

        Raises:
            ValueError: If ``top_logprobs`` is not None while ``logprobs`` is
                missing or falsy.
        """
        # A non-object body (list, string, ...) has no `.get`; pass it through so
        # normal validation reports a type error instead of an AttributeError.
        if not isinstance(values, dict):
            return values
        logprobs = values.get("logprobs")
        top_logprobs = values.get("top_logprobs")
        if top_logprobs is not None and not logprobs:
            raise ValueError("If top_logprobs is specified, logprobs must be set to true.")
        return values


#
# End of chat_request_models.py
#######################################################################################################################
